# -*- coding: utf-8 -*-
"""
Created on Tue Apr  2 20:10:00 2019

@author: xiang_yaobing
"""
import numpy as np
from tensorflow.keras.models import load_model
import os 
from PIL import Image
import tensorflow as tf
def _list_files(folder):
    """Return the full paths of the entries in *folder*, sorted by path."""
    # os.listdir returns entries in arbitrary order; sorting makes the
    # index-based pairing between two folders deterministic.
    return sorted(os.path.join(folder, name) for name in os.listdir(folder))


def _read_image(path):
    """Load the image at *path* into a NumPy array, closing the file after."""
    with Image.open(path) as img:
        return np.array(img)


def _stack_pairs(files_a, files_b):
    """Build one two-plane sample per (files_a[k], files_b[k]) pair.

    Raises:
        IndexError: if *files_b* has fewer entries than *files_a*.
    """
    if len(files_b) < len(files_a):
        raise IndexError(
            "second folder of the pair holds %d files but the first holds %d"
            % (len(files_b), len(files_a))
        )
    samples = []
    # zip stops at the end of files_a; extra entries in files_b are ignored.
    for file_a, file_b in zip(files_a, files_b):
        im1 = _read_image(file_a)
        im2 = _read_image(file_b)
        # NOTE(review): reshaping a (2, H, W) array to (H, W, 2) reinterprets
        # the memory layout; it is not the same as stacking the two images
        # along a channel axis. Kept unchanged because it presumably has to
        # match the preprocessing the saved model was trained with -- confirm.
        samples.append(
            np.reshape(np.array([im1, im2]), (im1.shape[0], im1.shape[1], 2))
        )
    return samples


def load_data(path1, path2, path3, path4):
    """Load paired images for two classes and build their 0/1 labels.

    Files in ``path1`` and ``path2`` are paired by position (after sorting)
    and become the samples labelled 0; files in ``path3`` and ``path4`` are
    paired the same way and become the samples labelled 1.

    Args:
        path1: Folder with the first image of every class-0 sample.
        path2: Folder with the second image of every class-0 sample.
        path3: Folder with the first image of every class-1 sample.
        path4: Folder with the second image of every class-1 sample.

    Returns:
        ``(x_test, y_test)``: ``x_test`` is the array of all samples, class 0
        first, then class 1; ``y_test`` is an integer array holding 0 for
        each class-0 sample followed by 1 for each class-1 sample.

    Raises:
        IndexError: if ``path2`` holds fewer files than ``path1``, or
            ``path4`` fewer than ``path3``.
    """
    # Read all four listings up front so a missing folder fails before any
    # image is decoded.
    filelist1 = _list_files(path1)
    filelist2 = _list_files(path2)
    filelist3 = _list_files(path3)
    filelist4 = _list_files(path4)

    class0_samples = _stack_pairs(filelist1, filelist2)
    # The original loop indexed these lists with the stale counter of the
    # previous loop, so every class-1 sample was the same image pair.
    class1_samples = _stack_pairs(filelist3, filelist4)

    x_test = np.array(class0_samples + class1_samples)
    # Labels are built here: two classes, so 0 and 1. Per the original
    # author's note, 0 means "no open cluster present".
    y_test = np.concatenate([
        np.zeros(len(class0_samples), dtype=int),
        np.ones(len(class1_samples), dtype=int),
    ])
    return x_test, y_test

def check_lv_data(data, model_path):
    """Classify *data* with the Keras model saved at *model_path*.

    Per the original author's note, the "l" in the name stands for position
    and the "v" for velocity.

    Args:
        data: Batch of samples accepted by the saved model; it is converted
            to a float32 array before prediction.
        model_path: Path of the saved model file passed to ``load_model``.

    Returns:
        ``(result, probability)``: ``probability`` is the raw output of
        ``model.predict``; ``result`` holds the predicted class index of
        every sample.
    """
    batch = np.asarray(data, dtype=np.float32)
    model = load_model(model_path)
    # summary() prints the table itself and returns None, so it must not be
    # wrapped in print().
    model.summary()

    # Predict once and derive the classes from the scores.
    # ``Sequential.predict_classes`` was removed in TensorFlow 2.6; the rule
    # below is the one it applied: argmax for a multi-unit output layer,
    # 0.5 threshold for a single-unit output layer.
    probability = np.array(model.predict(batch))
    if probability.shape[-1] > 1:
        result = probability.argmax(axis=-1)
    else:
        result = (probability > 0.5).astype("int32")
    return result, probability

#%% 
if __name__ == "__main__":
    # Test-set folders. load_data labels the (path1, path2) pair 0 and the
    # (path3, path4) pair 1.
    # NOTE(review): the "cluster" folders are passed as the first pair here,
    # while the alternative configuration below passes the "no cluster"
    # folders first -- confirm which label is meant to denote a cluster.
    path1 = 'E:\\疏散星团论文修改处理\\testset\\ClusterKDEpicture\\cluster\\loc0.1\\'
    path2 = 'E:\\疏散星团论文修改处理\\testset\\ClusterKDEpicture\\cluster\\loc\\'
    path3 = 'E:\\疏散星团论文修改处理\\testset\\noClusterKDEpicture\\no_cluster\\loc0.1\\'
    path4 = 'E:\\疏散星团论文修改处理\\testset\\noClusterKDEpicture\\no_cluster\\loc\\'

    # Alternative data set (kept for reference):
#    path1 = 'E:\\疏散星团论文修改处理\\NoClusterKDEpicture\\csv_dataset_no_cluster\\loc0.1\\'
#    path2 = 'E:\\疏散星团论文修改处理\\NoClusterKDEpicture\\csv_dataset_no_cluster\\loc\\'
#    path3 = 'E:\\疏散星团论文修改处理\\ClusterKDEpicture\\csv_dataset\\loc0.1\\'
#    path4 = 'E:\\疏散星团论文修改处理\\ClusterKDEpicture\\csv_dataset\\loc\\'

    # Load the test set once; the original read every image from disk twice.
    data, labels = load_data(path1, path2, path3, path4)
    # Alias kept for interactive sessions that relied on the old name.
    label = labels
    #%%
    model_path = './/cnn_model_loc_v1.h5'
    result, prob = check_lv_data(data, model_path)

            
